package tree

import (
	"fmt"
	"gitee.com/kklt1996/data-structure/common"
	"strings"
)

/*
	segTreeNode is a single node of the segment tree.
*/
type segTreeNode struct {
	// left and right are the inclusive bounds of the index range covered by
	// this node; a leaf has left == right.
	left  int
	right int

	// value is the aggregated value stored for the range [left, right].
	value interface{}

	// leftNode and rightNode are the children covering the lower and upper
	// halves of the range; both are nil on a leaf.
	leftNode  *segTreeNode
	rightNode *segTreeNode
}

/*
	SegmentTree stores a slice of elements together with a tree of
	pre-aggregated ranges built over it.
*/
type SegmentTree struct {
	// data holds the underlying elements; its length is the size of the tree.
	data []interface{}

	// root is the root node of the segment tree.
	root *segTreeNode

	// aggregator combines the values of a left and a right child into the
	// value of their parent node. For a leaf it is called with the element
	// as the first argument and nil as the second.
	aggregator func(leftChildValue, rightChildValue interface{}) interface{}
}

/*
	GetSize returns the number of elements the segment tree is built over.
*/
func (tree SegmentTree) GetSize() int {
	size := len(tree.data)
	return size
}

/*
	Query returns the aggregated value of the elements in the closed index
	range [queryL, queryR]. O(log(n)).

	Both bounds must be valid element indices and queryL must not exceed
	queryR, i.e. 0 <= queryL <= queryR < GetSize(). Otherwise nil and a
	common.IndexError are returned.
*/
func (tree SegmentTree) Query(queryL int, queryR int) (interface{}, error) {
	size := len(tree.data)
	// Valid indices are 0..size-1. An index equal to size must be rejected:
	// the recursive search would otherwise walk past a leaf into a nil child.
	if queryL < 0 || queryL >= size || queryR < 0 || queryR >= size || queryR < queryL {
		return nil, common.IndexError{}
	}
	return tree.query(tree.root, queryL, queryR), nil
}

/*
	query returns the aggregated value for [queryL, queryR] within the subtree
	rooted at root. The requested range is expected to lie inside the range
	covered by root.
*/
func (tree SegmentTree) query(root *segTreeNode, queryL int, queryR int) interface{} {
	// The node covers exactly the requested range: its stored value is the answer.
	if queryL == root.left && queryR == root.right {
		return root.value
	}

	mid := root.left + (root.right-root.left)/2
	switch {
	case queryR <= mid:
		// The whole range lies in the left child.
		return tree.query(root.leftNode, queryL, queryR)
	case queryL > mid:
		// The whole range lies in the right child.
		return tree.query(root.rightNode, queryL, queryR)
	default:
		// The range straddles the midpoint: split it at mid, query both
		// halves and combine the partial results.
		lowerPart := tree.query(root.leftNode, queryL, mid)
		upperPart := tree.query(root.rightNode, mid+1, queryR)
		return tree.aggregator(lowerPart, upperPart)
	}
}

/*
	Set replaces the element at index i with value and refreshes the
	aggregated values of the tree nodes whose range contains i. O(log(n)).

	A common.IndexError is returned when i is not a valid index into the
	underlying data; nil is returned on success.
*/
func (tree *SegmentTree) Set(i int, value interface{}) error {
	if inRange := 0 <= i && i < len(tree.data); !inRange {
		return common.IndexError{}
	}
	// Update the raw element first, then propagate the change through the tree.
	tree.data[i] = value
	tree.set(tree.root, i, value)
	return nil
}

/*
	set applies the new value for position index inside the subtree rooted at
	root, then recomputes the value of every node on the path from the
	affected leaf back up to root. It returns nothing; nodes are updated in
	place.
*/
func (tree *SegmentTree) set(root *segTreeNode, index int, value interface{}) {
	// Leaf: store the new element, passed through the aggregator with a nil partner.
	if root.left == root.right {
		root.value = tree.aggregator(value, nil)
		return
	}

	// Descend into the half of the range that contains index.
	child := root.rightNode
	if mid := root.left + (root.right-root.left)/2; index <= mid {
		child = root.leftNode
	}
	tree.set(child, index, value)

	// One child has changed, so re-aggregate this node from both children.
	root.value = tree.aggregator(root.leftNode.value, root.rightNode.value)
}

/*
	init builds the subtree covering the closed index range [l, r] of
	tree.data and returns its root node.
*/
func (tree *SegmentTree) init(l int, r int) *segTreeNode {
	node := &segTreeNode{left: l, right: r}

	// Single element: a leaf holding that element passed through the
	// aggregator with a nil partner.
	if l == r {
		node.value = tree.aggregator(tree.data[l], nil)
		return node
	}

	// Split the range at its midpoint, build the left half, then the right
	// half, and aggregate the two child values into this node.
	mid := l + (r-l)/2
	node.leftNode = tree.init(l, mid)
	node.rightNode = tree.init(mid+1, r)
	node.value = tree.aggregator(node.leftNode.value, node.rightNode.value)
	return node
}

/*
	preOrder walks the subtree rooted at treeNode in pre-order (node, then left
	subtree, then right subtree) and calls operatorFunc with the aggregated
	value stored in each visited node. A nil treeNode is a no-op.
*/
func (tree SegmentTree) preOrder(treeNode *segTreeNode, operatorFunc func(value interface{})) {
	if treeNode == nil {
		return
	}
	// Pass the node's value rather than the node itself, so the callback
	// sees the aggregated data and not the internal struct with its child
	// pointers.
	operatorFunc(treeNode.value)
	tree.preOrder(treeNode.leftNode, operatorFunc)
	tree.preOrder(treeNode.rightNode, operatorFunc)
}

/*
	String renders the tree as a bracketed, comma-separated list in pre-order,
	e.g. "[a,b,c]"; a tree without nodes renders as "[]". Each item handed to
	the callback by the pre-order traversal is formatted with %v, and a nil
	item is written as the literal text nil.
*/
func (tree SegmentTree) String() string {
	// A strings.Builder avoids re-copying the accumulated text on every append.
	var sb strings.Builder
	sb.WriteString("[")
	tree.preOrder(tree.root, func(value interface{}) {
		if value == nil {
			sb.WriteString("nil,")
			return
		}
		fmt.Fprintf(&sb, "%v,", value)
	})
	// Every item is written with a trailing comma; drop the final one before
	// closing the bracket.
	return strings.TrimSuffix(sb.String(), ",") + "]"
}
